import os
import pandas as pd
import numpy as np
import random
import math
import argparse
import configparser
from typing import List, Tuple, Dict, Any
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def read_mouse_data(file_path):
    """
    从 CSV 文件中读取小鼠数据
    :param file_path: CSV 文件路径
    :return: 包含小鼠数据的 DataFrame
    """
    return pd.read_csv(file_path)

def calculate_group_stats(groups: List[List[int]], data: pd.DataFrame) -> List[Tuple[float, float, float, float]]:
    """
    计算每个组的肿瘤体积和体重的平均值和标准差
    :param groups: 分组列表，每个元素是该组的索引列表
    :param data: 小鼠数据 DataFrame
    :return: 每个组的肿瘤体积和体重的平均值和标准差的元组列表
    """
    stats = []
    for group_indices in groups:
        if not group_indices:  # 空组处理
            stats.append((0.0, 0.0, 0.0, 0.0))
            continue
            
        group_data = data.iloc[group_indices]
        volume_mean = float(group_data['tumor_volume'].mean())
        volume_std = float(group_data['tumor_volume'].std())
        weight_mean = float(group_data['weight'].mean())
        weight_std = float(group_data['weight'].std())
        stats.append((volume_mean, volume_std, weight_mean, weight_std))
    return stats

def objective_function(stats: List[Tuple[float, float, float, float]], 
                      overall_volume_mean: float, overall_volume_std: float,
                      overall_weight_mean: float, overall_weight_std: float) -> float:
    """
    计算目标函数值，目标是使各组之间的肿瘤体积和体重的平均值和标准差与总体值差异最小
    
    目标函数计算各组统计量与总体统计量的平方差之和，值越小表示分组越均衡
    
    :param stats: 每个组的肿瘤体积和体重的平均值和标准差的元组列表
    :param overall_volume_mean: 所有小鼠肿瘤体积的平均值
    :param overall_volume_std: 所有小鼠肿瘤体积的标准差
    :param overall_weight_mean: 所有小鼠体重的平均值
    :param overall_weight_std: 所有小鼠体重的标准差
    :return: 目标函数值，数值越小表示分组越均衡
    """
    if not stats:
        return float('inf')
    
    # 使用NumPy数组进行向量化计算
    stats_array = np.array(stats)
    volume_means = stats_array[:, 0]
    volume_stds = stats_array[:, 1]
    weight_means = stats_array[:, 2]
    weight_stds = stats_array[:, 3]

    # 计算各项差异的平方和
    volume_mean_diff = np.sum((volume_means - overall_volume_mean) ** 2)
    volume_std_diff = np.sum((volume_stds - overall_volume_std) ** 2)
    weight_mean_diff = np.sum((weight_means - overall_weight_mean) ** 2)
    weight_std_diff = np.sum((weight_stds - overall_weight_std) ** 2)

    return float(volume_mean_diff + volume_std_diff + weight_mean_diff + weight_std_diff)

def simulated_annealing(data: pd.DataFrame, num_groups: int, initial_temp: float, 
                       final_temp: float, alpha: float, max_iter: int,
                       overall_volume_mean: float, overall_volume_std: float,
                       overall_weight_mean: float, overall_weight_std: float) -> List[List[int]]:
    """
    模拟退火算法优化分组方案
    
    该算法通过随机交换小鼠分组来寻找最优解，使用温度参数控制接受劣解的概率，
    逐步降低温度以收敛到最优解
    
    :param data: 小鼠数据 DataFrame
    :param num_groups: 分组数量
    :param initial_temp: 初始温度
    :param final_temp: 最终温度
    :param alpha: 降温系数
    :param max_iter: 最大迭代次数
    :param overall_volume_mean: 所有小鼠肿瘤体积的平均值
    :param overall_volume_std: 所有小鼠肿瘤体积的标准差
    :param overall_weight_mean: 所有小鼠体重的平均值
    :param overall_weight_std: 所有小鼠体重的标准差
    :return: 最优分组方案，每个元素是该组的索引列表
    :raises ValueError: 当参数无效时
    """
    if num_groups <= 0:
        raise ValueError("分组数量必须大于0")
    if initial_temp <= 0 or final_temp <= 0 or alpha <= 0 or alpha >= 1:
        raise ValueError("温度参数和降温系数必须在有效范围内")
    if max_iter <= 0:
        raise ValueError("最大迭代次数必须大于0")
    
    num_mice = len(data)
    if num_mice < num_groups:
        raise ValueError("小鼠数量必须大于分组数量")
    
    group_size = num_mice // num_groups
    remaining = num_mice % num_groups

    # 初始分组 - 均匀分配
    groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]

    # 处理剩余小鼠 - 分配到使目标函数最小的组
    remaining_mice = list(range(num_mice - remaining, num_mice))
    for mouse in remaining_mice:
        best_group_index = _find_best_group_for_mouse(groups, mouse, data, 
                                                     overall_volume_mean, overall_volume_std,
                                                     overall_weight_mean, overall_weight_std)
        groups[best_group_index].append(mouse)

    # 初始化当前解和最优解
    current_stats = calculate_group_stats(groups, data)
    current_obj = objective_function(current_stats, overall_volume_mean, overall_volume_std,
                                   overall_weight_mean, overall_weight_std)
    best_groups = [group.copy() for group in groups]
    best_obj = current_obj
    temp = initial_temp

    # 模拟退火主循环
    for iteration in range(max_iter):
        # 生成新解：随机交换两只小鼠的分组
        new_groups = [group.copy() for group in groups]
        group1, group2 = random.sample(range(num_groups), 2)
        
        if new_groups[group1] and new_groups[group2]:  # 确保两个组都有小鼠
            index1 = random.randint(0, len(new_groups[group1]) - 1)
            index2 = random.randint(0, len(new_groups[group2]) - 1)
            new_groups[group1][index1], new_groups[group2][index2] = \
                new_groups[group2][index2], new_groups[group1][index1]

            new_stats = calculate_group_stats(new_groups, data)
            new_obj = objective_function(new_stats, overall_volume_mean, overall_volume_std,
                                       overall_weight_mean, overall_weight_std)

            # Metropolis准则判断是否接受新解
            if new_obj < current_obj or random.random() < math.exp((current_obj - new_obj) / temp):
                groups = new_groups
                current_obj = new_obj

            # 更新最优解
            if new_obj < best_obj:
                best_groups = [group.copy() for group in new_groups]
                best_obj = new_obj

        # 降温
        temp *= alpha
        if temp < final_temp:
            logging.info(f"温度降至 {temp:.6f}，提前终止迭代 (迭代次数: {iteration + 1})")
            break
    
    logging.info(f"模拟退火完成，最优目标函数值: {best_obj:.6f}")
    return best_groups

def _find_best_group_for_mouse(groups: List[List[int]], mouse: int, data: pd.DataFrame,
                              overall_volume_mean: float, overall_volume_std: float,
                              overall_weight_mean: float, overall_weight_std: float) -> int:
    """
    为新小鼠找到最佳的分组
    
    :param groups: 当前的分组方案
    :param mouse: 新小鼠的索引
    :param data: 小鼠数据 DataFrame
    :param overall_volume_mean: 所有小鼠肿瘤体积的平均值
    :param overall_volume_std: 所有小鼠肿瘤体积的标准差
    :param overall_weight_mean: 所有小鼠体重的平均值
    :param overall_weight_std: 所有小鼠体重的标准差
    :return: 最佳分组索引
    """
    best_obj = float('inf')
    best_group_index = 0
    
    for group_index in range(len(groups)):
        temp_groups = [group.copy() for group in groups]
        temp_groups[group_index].append(mouse)
        temp_stats = calculate_group_stats(temp_groups, data)
        temp_obj = objective_function(temp_stats, overall_volume_mean, overall_volume_std,
                                   overall_weight_mean, overall_weight_std)
        if temp_obj < best_obj:
            best_obj = temp_obj
            best_group_index = group_index
    
    return best_group_index

def read_last_grouping(file_path: str) -> Tuple[List[List[int]], pd.DataFrame]:
    """
    从 CSV 文件中读取上一次的分组结果
    :param file_path: 包含分组结果的 CSV 文件路径
    :return: 上一次的分组列表和原始数据
    """
    data_with_group = pd.read_csv(file_path)
    groups = [[] for _ in range(data_with_group['group'].max())]
    for index, row in data_with_group.iterrows():
        group_index = int(row['group']) - 1
        groups[group_index].append(index)
    return groups, data_with_group.drop('group', axis=1)

def load_config(config_file: str, section: str) -> Dict[str, Any]:
    """
    加载配置文件
    
    :param config_file: 配置文件路径
    :param section: 配置节名
    :return: 配置字典
    :raises FileNotFoundError: 当配置文件不存在时
    :raises ValueError: 当配置格式不正确时
    """
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"配置文件未找到: {config_file}")
    
    config = configparser.ConfigParser()
    config.read(config_file, encoding='utf-8')
    
    if section not in config:
        raise ValueError(f"配置文件中缺少节: {section}")
    
    return dict(config[section])

def save_results(data: pd.DataFrame, groups: List[List[int]], result_folder: str, run_number: int,
                overall_stats: Dict[str, float]) -> None:
    """
    保存分组结果和统计信息
    
    :param data: 小鼠数据
    :param groups: 分组方案
    :param result_folder: 结果文件夹
    :param run_number: 运行编号
    :param overall_stats: 总体统计信息
    """
    os.makedirs(result_folder, exist_ok=True)
    
    # 生成分组数据文件
    csv_file_name = os.path.join(result_folder, f'grouped_mouse_data_run{run_number}.csv')
    stats_csv_name = os.path.join(result_folder, f'group_stats_run{run_number}.csv')
    
    # 写入分组数据
    data_with_group = data.copy()
    data_with_group['group'] = None
    for i, group in enumerate(groups):
        data_with_group.loc[group, 'group'] = i + 1
    data_with_group.to_csv(csv_file_name, index=False)
    
    # 写入统计信息
    stats_data = []
    group_stats = []  # 用于计算目标函数值
    
    for i, group in enumerate(groups):
        group_data = data.iloc[group]
        volume_mean = float(group_data['tumor_volume'].mean())
        volume_std = float(group_data['tumor_volume'].std())
        weight_mean = float(group_data['weight'].mean())
        weight_std = float(group_data['weight'].std())
        
        group_stats.append((volume_mean, volume_std, weight_mean, weight_std))
        
        stats_data.append({
            'Group': i + 1,
            'Mean_tumor_volume': volume_mean,
            'Std_tumor_volume': volume_std,
            'Mean_weight': weight_mean,
            'Std_weight': weight_std,
            'Count': len(group)
        })
    
    # 计算目标函数值
    objective_value = objective_function(
        group_stats, 
        overall_stats['volume_mean'], 
        overall_stats['volume_std'],
        overall_stats['weight_mean'], 
        overall_stats['weight_std']
    )
    
    # 将目标函数值添加到每个组的统计信息中
    for group_data in stats_data:
        group_data['Objective_value'] = objective_value
    
    # 添加总体统计
    stats_data.append({
        'Group': 'Overall',
        'Mean_tumor_volume': overall_stats['volume_mean'],
        'Std_tumor_volume': overall_stats['volume_std'],
        'Mean_weight': overall_stats['weight_mean'],
        'Std_weight': overall_stats['weight_std'],
        'Count': len(data),
        'Objective_value': objective_value
    })
    
    pd.DataFrame(stats_data).to_csv(stats_csv_name, index=False)
    logging.info(f"结果已保存到: {result_folder}, 目标函数值: {objective_value:.6f}")

def main():
    """
    主函数：执行小鼠分组任务
    """
    parser = argparse.ArgumentParser(description='为小鼠肿瘤模型自动分组')
    parser.add_argument('input_file', type=str, help='原始数据 CSV 文件的文件名')
    parser.add_argument('--last_grouping', type=str, help='上一次分组结果的 CSV 文件路径')
    parser.add_argument('--config_dir', type=str, default='.', help='配置文件目录')
    args = parser.parse_args()

    try:
        # 加载配置
        grouping_config = load_config(os.path.join(args.config_dir, 'group.ini'), 'grouping')
        sa_config = load_config(os.path.join(args.config_dir, 'config.ini'), 'simulated_annealing')
        
        # 解析配置参数
        num_groups = int(grouping_config['num_groups'])
        group_sizes = [int(size) for size in grouping_config['group_sizes'].split(',')]
        runs = int(grouping_config['runs'])
        
        sa_params = {
            'initial_temp': float(sa_config['initial_temp']),
            'final_temp': float(sa_config['final_temp']),
            'alpha': float(sa_config['alpha']),
            'max_iter': int(sa_config['max_iter'])
        }
        
        # 验证配置
        if len(group_sizes) != num_groups:
            raise ValueError("分组数量和每组动物数量不匹配")
        
        # 读取数据
        logging.info(f"读取数据文件: {args.input_file}")
        new_data = read_mouse_data(args.input_file)
        logging.info(f"读取到 {len(new_data)} 只小鼠的数据")

        # 处理历史分组
        if args.last_grouping:
            logging.info(f"读取历史分组: {args.last_grouping}")
            last_groups, last_data = read_last_grouping(args.last_grouping)
            data = pd.concat([last_data, new_data], ignore_index=True)
            offset = len(last_data)
            new_mice_indices = list(range(offset, len(data)))
            logging.info(f"新增 {len(new_mice_indices)} 只小鼠")
        else:
            data = new_data
            last_groups = []
            new_mice_indices = list(range(len(data)))

        # 计算总体统计量
        overall_stats = {
            'volume_mean': float(data['tumor_volume'].mean()),
            'volume_std': float(data['tumor_volume'].std()),
            'weight_mean': float(data['weight'].mean()),
            'weight_std': float(data['weight'].std())
        }
        
        logging.info(f"总体统计 - 肿瘤体积: {overall_stats['volume_mean']:.2f}±{overall_stats['volume_std']:.2f}, "
                    f"体重: {overall_stats['weight_mean']:.2f}±{overall_stats['weight_std']:.2f}")

        # 创建结果文件夹
        file_name_without_ext = os.path.splitext(os.path.basename(args.input_file))[0]
        result_folder = file_name_without_ext

        # 执行分组
        for run in range(runs):
            logging.info(f"开始第 {run + 1} 次分组运行")
            
            if args.last_grouping:
                # 保持上一次分组，只对新小鼠进行分组
                groups = [group.copy() for group in last_groups]
                group_counts = [len(group) for group in groups]
                
                for mouse in new_mice_indices:
                    min_count = min(group_counts)
                    candidate_groups = [i for i, count in enumerate(group_counts) 
                                      if count == min_count]
                    
                    if len(candidate_groups) == 1:
                        best_group_index = candidate_groups[0]
                    else:
                        best_group_index = _find_best_group_for_mouse(
                            groups, mouse, data, overall_stats['volume_mean'], 
                            overall_stats['volume_std'], overall_stats['weight_mean'], 
                            overall_stats['weight_std'])
                    
                    groups[best_group_index].append(mouse)
                    group_counts[best_group_index] += 1
            else:
                # 全新分组
                groups = simulated_annealing(
                    data, num_groups, sa_params['initial_temp'], sa_params['final_temp'],
                    sa_params['alpha'], sa_params['max_iter'], overall_stats['volume_mean'],
                    overall_stats['volume_std'], overall_stats['weight_mean'], 
                    overall_stats['weight_std'])

            # 保存结果
            save_results(data, groups, result_folder, run + 1, overall_stats)
            
            # 验证分组结果
            total_mice = sum(len(group) for group in groups)
            if total_mice != len(data):
                logging.warning(f"分组小鼠总数({total_mice})与原始数据({len(data)})不匹配")

        logging.info(f"所有 {runs} 次分组运行完成")

    except Exception as e:
        logging.error(f"执行失败: {e}")
        raise

if __name__ == "__main__":
    main()